import torch
from pytorch_q_network import DeepNetWork

if __name__ == "__main__":
    # Build a dummy input plus the network. The dummy tensor is only used to
    # trace the graph during export. Its shape (1, 4, 80, 80) must match what
    # DeepNetWork.forward expects — presumably a batch of one stack of four
    # 80x80 frames; TODO confirm against pytorch_q_network.
    data = torch.randn(1, 4, 80, 80)
    net = DeepNetWork()
    # map_location="cpu" lets a checkpoint that was saved from a CUDA device
    # load on a CPU-only machine. It also keeps the weights on the same device
    # as the CPU dummy input above, which tracing requires.
    # NOTE(review): consider weights_only=True (torch >= 1.13) if the .pth
    # file may come from an untrusted source; torch.load unpickles by default.
    net.load_state_dict(torch.load("Q_net.pth", map_location="cpu"))
    # Put the network in inference mode explicitly. torch.onnx.export also
    # exports in eval mode by default; this line just makes the intent
    # visible (it matters for dropout / batch-norm layers, if any).
    net.eval()
    # Export to ONNX format, embedding the trained parameters in the file.
    torch.onnx.export(
        net,
        data,
        'model.onnx',
        export_params=True,
        opset_version=13,
    )
